# pylint: disable=invalid-name,too-many-locals,line-too-long,no-else-raise,too-many-arguments,no-self-use,too-many-statements,stop-iteration-return,import-outside-toplevel
import typing

# The PyTorch portions of this code are subject to the following copyright notice.
# Copyright (c) 2019-present NAVER Corp.

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.

import cv2
import numpy as np
import tensorflow as tf
import efficientnet.tfkeras as efficientnet
from tensorflow import keras

from . import tools


def compute_input(image):
    # should be RGB order
    image = image.astype("float32")
    mean = np.array([0.485, 0.456, 0.406])
    variance = np.array([0.229, 0.224, 0.225])

    image -= mean * 255
    image /= variance * 255
    return image


def invert_input(X):
    X = X.copy()
    mean = np.array([0.485, 0.456, 0.406])
    variance = np.array([0.229, 0.224, 0.225])

    X *= variance * 255
    X += mean * 255
    return X.clip(0, 255).astype("uint8")


def get_gaussian_heatmap(size=512, distanceRatio=3.34):
    v = np.abs(np.linspace(-size / 2, size / 2, num=size))
    x, y = np.meshgrid(v, v)
    g = np.sqrt(x**2 + y**2)
    g *= distanceRatio / (size / 2)
    g = np.exp(-(1 / 2) * (g**2))
    g *= 255
    return g.clip(0, 255).astype("uint8")


def upconv(x, n, filters):
    x = keras.layers.Conv2D(
        filters=filters, kernel_size=1, strides=1, name=f"upconv{n}.conv.0"
    )(x)
    x = keras.layers.BatchNormalization(
        epsilon=1e-5, momentum=0.9, name=f"upconv{n}.conv.1"
    )(x)
    x = keras.layers.Activation("relu", name=f"upconv{n}.conv.2")(x)
    x = keras.layers.Conv2D(
        filters=filters // 2,
        kernel_size=3,
        strides=1,
        padding="same",
        name=f"upconv{n}.conv.3",
    )(x)
    x = keras.layers.BatchNormalization(
        epsilon=1e-5, momentum=0.9, name=f"upconv{n}.conv.4"
    )(x)
    x = keras.layers.Activation("relu", name=f"upconv{n}.conv.5")(x)
    return x


def make_vgg_block(x, filters, n, prefix, pooling=True):
    x = keras.layers.Conv2D(
        filters=filters,
        strides=(1, 1),
        kernel_size=(3, 3),
        padding="same",
        name=f"{prefix}.{n}",
    )(x)
    x = keras.layers.BatchNormalization(
        momentum=0.1, epsilon=1e-5, axis=-1, name=f"{prefix}.{n+1}"
    )(x)
    x = keras.layers.Activation("relu", name=f"{prefix}.{n+2}")(x)
    if pooling:
        x = keras.layers.MaxPooling2D(
            pool_size=(2, 2), padding="valid", strides=(2, 2), name=f"{prefix}.{n+3}"
        )(x)
    return x


def compute_maps(heatmap, image_height, image_width, lines):
    assert image_height % 2 == 0, "Height must be an even number"
    assert image_width % 2 == 0, "Width must be an even number"

    textmap = np.zeros((image_height // 2, image_width // 2)).astype("float32")
    linkmap = np.zeros((image_height // 2, image_width // 2)).astype("float32")

    src = np.array(
        [
            [0, 0],
            [heatmap.shape[1], 0],
            [heatmap.shape[1], heatmap.shape[0]],
            [0, heatmap.shape[0]],
        ]
    ).astype("float32")

    for line in lines:
        line, orientation = tools.fix_line(line)
        previous_link_points = None
        for [(x1, y1), (x2, y2), (x3, y3), (x4, y4)], c in line:
            x1, y1, x2, y2, x3, y3, x4, y4 = map(
                lambda v: max(v, 0), [x1, y1, x2, y2, x3, y3, x4, y4]
            )
            if c == " ":
                previous_link_points = None
                continue
            yc = (y4 + y1 + y3 + y2) / 4
            xc = (x1 + x2 + x3 + x4) / 4
            if orientation == "horizontal":
                current_link_points = (
                    np.array(
                        [
                            [(xc + (x1 + x2) / 2) / 2, (yc + (y1 + y2) / 2) / 2],
                            [(xc + (x3 + x4) / 2) / 2, (yc + (y3 + y4) / 2) / 2],
                        ]
                    )
                    / 2
                )
            else:
                current_link_points = (
                    np.array(
                        [
                            [(xc + (x1 + x4) / 2) / 2, (yc + (y1 + y4) / 2) / 2],
                            [(xc + (x2 + x3) / 2) / 2, (yc + (y2 + y3) / 2) / 2],
                        ]
                    )
                    / 2
                )
            character_points = (
                np.array([[x1, y1], [x2, y2], [x3, y3], [x4, y4]]).astype("float32") / 2
            )
            # pylint: disable=unsubscriptable-object
            if previous_link_points is not None:
                if orientation == "horizontal":
                    link_points = np.array(
                        [
                            previous_link_points[0],
                            current_link_points[0],
                            current_link_points[1],
                            previous_link_points[1],
                        ]
                    )
                else:
                    link_points = np.array(
                        [
                            previous_link_points[0],
                            previous_link_points[1],
                            current_link_points[1],
                            current_link_points[0],
                        ]
                    )
                ML = cv2.getPerspectiveTransform(
                    src=src,
                    dst=link_points.astype("float32"),
                )
                linkmap += cv2.warpPerspective(
                    heatmap, ML, dsize=(linkmap.shape[1], linkmap.shape[0])
                ).astype("float32")
            MA = cv2.getPerspectiveTransform(
                src=src,
                dst=character_points,
            )
            textmap += cv2.warpPerspective(
                heatmap, MA, dsize=(textmap.shape[1], textmap.shape[0])
            ).astype("float32")
            # pylint: enable=unsubscriptable-object
            previous_link_points = current_link_points
    return (
        np.concatenate(
            [textmap[..., np.newaxis], linkmap[..., np.newaxis]], axis=2
        ).clip(0, 255)
        / 255
    )


def map_to_rgb(y):
    return (
        np.concatenate([y, np.zeros((y.shape[0], y.shape[1], 1))], axis=-1) * 255
    ).astype("uint8")


def getBoxes(
    y_pred,
    detection_threshold=0.7,
    text_threshold=0.4,
    link_threshold=0.4,
    size_threshold=10,
):
    box_groups = []
    for y_pred_cur in y_pred:
        # Prepare data
        textmap = y_pred_cur[..., 0].copy()
        linkmap = y_pred_cur[..., 1].copy()
        img_h, img_w = textmap.shape

        _, text_score = cv2.threshold(
            textmap, thresh=text_threshold, maxval=1, type=cv2.THRESH_BINARY
        )
        _, link_score = cv2.threshold(
            linkmap, thresh=link_threshold, maxval=1, type=cv2.THRESH_BINARY
        )
        n_components, labels, stats, _ = cv2.connectedComponentsWithStats(
            np.clip(text_score + link_score, 0, 1).astype("uint8"), connectivity=4
        )
        boxes = []
        for component_id in range(1, n_components):
            # Filter by size
            size = stats[component_id, cv2.CC_STAT_AREA]

            if size < size_threshold:
                continue

            # If the maximum value within this connected component is less than
            # text threshold, we skip it.
            if np.max(textmap[labels == component_id]) < detection_threshold:
                continue

            # Make segmentation map. It is 255 where we find text, 0 otherwise.
            segmap = np.zeros_like(textmap)
            segmap[labels == component_id] = 255
            segmap[np.logical_and(link_score, text_score)] = 0
            x, y, w, h = [
                stats[component_id, key]
                for key in [
                    cv2.CC_STAT_LEFT,
                    cv2.CC_STAT_TOP,
                    cv2.CC_STAT_WIDTH,
                    cv2.CC_STAT_HEIGHT,
                ]
            ]

            # Expand the elements of the segmentation map
            niter = int(np.sqrt(size * min(w, h) / (w * h)) * 2)
            sx, sy = max(x - niter, 0), max(y - niter, 0)
            ex, ey = min(x + w + niter + 1, img_w), min(y + h + niter + 1, img_h)
            segmap[sy:ey, sx:ex] = cv2.dilate(
                segmap[sy:ey, sx:ex],
                cv2.getStructuringElement(cv2.MORPH_RECT, (1 + niter, 1 + niter)),
            )

            # Make rotated box from contour
            contours = cv2.findContours(
                segmap.astype("uint8"),
                mode=cv2.RETR_TREE,
                method=cv2.CHAIN_APPROX_SIMPLE,
            )[-2]
            contour = contours[0]
            box = cv2.boxPoints(cv2.minAreaRect(contour))

            # Check to see if we have a diamond
            w, h = np.linalg.norm(box[0] - box[1]), np.linalg.norm(box[1] - box[2])
            box_ratio = max(w, h) / (min(w, h) + 1e-5)
            if abs(1 - box_ratio) <= 0.1:
                l, r = contour[:, 0, 0].min(), contour[:, 0, 0].max()
                t, b = contour[:, 0, 1].min(), contour[:, 0, 1].max()
                box = np.array([[l, t], [r, t], [r, b], [l, b]], dtype=np.float32)
            else:
                # Make clock-wise order
                box = np.array(np.roll(box, 4 - box.sum(axis=1).argmin(), 0))
            boxes.append(2 * box)
        box_groups.append(np.array(boxes))
    return box_groups


class UpsampleLike(keras.layers.Layer):
    """Keras layer for upsampling a Tensor to be the same shape as another Tensor."""

    # pylint:disable=unused-argument
    def call(self, inputs, **kwargs):
        source, target = inputs
        target_shape = keras.backend.shape(target)
        if keras.backend.image_data_format() == "channels_first":
            raise NotImplementedError
        else:
            # pylint: disable=no-member
            return tf.compat.v1.image.resize_bilinear(
                source, size=(target_shape[1], target_shape[2]), half_pixel_centers=True
            )

    def compute_output_shape(self, input_shape):
        if keras.backend.image_data_format() == "channels_first":
            raise NotImplementedError
        else:
            return (input_shape[0][0],) + input_shape[1][1:3] + (input_shape[0][-1],)


def build_vgg_backbone(inputs):
    x = make_vgg_block(inputs, filters=64, n=0, pooling=False, prefix="basenet.slice1")
    x = make_vgg_block(x, filters=64, n=3, pooling=True, prefix="basenet.slice1")
    x = make_vgg_block(x, filters=128, n=7, pooling=False, prefix="basenet.slice1")
    x = make_vgg_block(x, filters=128, n=10, pooling=True, prefix="basenet.slice1")
    x = make_vgg_block(x, filters=256, n=14, pooling=False, prefix="basenet.slice2")
    x = make_vgg_block(x, filters=256, n=17, pooling=False, prefix="basenet.slice2")
    x = make_vgg_block(x, filters=256, n=20, pooling=True, prefix="basenet.slice3")
    x = make_vgg_block(x, filters=512, n=24, pooling=False, prefix="basenet.slice3")
    x = make_vgg_block(x, filters=512, n=27, pooling=False, prefix="basenet.slice3")
    x = make_vgg_block(x, filters=512, n=30, pooling=True, prefix="basenet.slice4")
    x = make_vgg_block(x, filters=512, n=34, pooling=False, prefix="basenet.slice4")
    x = make_vgg_block(x, filters=512, n=37, pooling=False, prefix="basenet.slice4")
    x = make_vgg_block(x, filters=512, n=40, pooling=True, prefix="basenet.slice4")
    vgg = keras.models.Model(inputs=inputs, outputs=x)
    return [
        vgg.get_layer(slice_name).output
        for slice_name in [
            "basenet.slice1.12",
            "basenet.slice2.19",
            "basenet.slice3.29",
            "basenet.slice4.38",
        ]
    ]


def build_efficientnet_backbone(inputs, backbone_name, imagenet):
    backbone = getattr(efficientnet, backbone_name)(
        include_top=False, input_tensor=inputs, weights="imagenet" if imagenet else None
    )
    return [
        backbone.get_layer(slice_name).output
        for slice_name in [
            "block2a_expand_activation",
            "block3a_expand_activation",
            "block4a_expand_activation",
            "block5a_expand_activation",
        ]
    ]


def build_keras_model(weights_path: str = None, backbone_name="vgg"):
    inputs = keras.layers.Input((None, None, 3))

    if backbone_name == "vgg":
        s1, s2, s3, s4 = build_vgg_backbone(inputs)
    elif "efficientnet" in backbone_name.lower():
        s1, s2, s3, s4 = build_efficientnet_backbone(
            inputs=inputs, backbone_name=backbone_name, imagenet=weights_path is None
        )
    else:
        raise NotImplementedError

    s5 = keras.layers.MaxPooling2D(
        pool_size=3, strides=1, padding="same", name="basenet.slice5.0"
    )(s4)
    s5 = keras.layers.Conv2D(
        1024,
        kernel_size=(3, 3),
        padding="same",
        strides=1,
        dilation_rate=6,
        name="basenet.slice5.1",
    )(s5)
    s5 = keras.layers.Conv2D(
        1024, kernel_size=1, strides=1, padding="same", name="basenet.slice5.2"
    )(s5)

    y = keras.layers.Concatenate()([s5, s4])
    y = upconv(y, n=1, filters=512)
    y = UpsampleLike()([y, s3])
    y = keras.layers.Concatenate()([y, s3])
    y = upconv(y, n=2, filters=256)
    y = UpsampleLike()([y, s2])
    y = keras.layers.Concatenate()([y, s2])
    y = upconv(y, n=3, filters=128)
    y = UpsampleLike()([y, s1])
    y = keras.layers.Concatenate()([y, s1])
    features = upconv(y, n=4, filters=64)

    y = keras.layers.Conv2D(
        filters=32, kernel_size=3, strides=1, padding="same", name="conv_cls.0"
    )(features)
    y = keras.layers.Activation("relu", name="conv_cls.1")(y)
    y = keras.layers.Conv2D(
        filters=32, kernel_size=3, strides=1, padding="same", name="conv_cls.2"
    )(y)
    y = keras.layers.Activation("relu", name="conv_cls.3")(y)
    y = keras.layers.Conv2D(
        filters=16, kernel_size=3, strides=1, padding="same", name="conv_cls.4"
    )(y)
    y = keras.layers.Activation("relu", name="conv_cls.5")(y)
    y = keras.layers.Conv2D(
        filters=16, kernel_size=1, strides=1, padding="same", name="conv_cls.6"
    )(y)
    y = keras.layers.Activation("relu", name="conv_cls.7")(y)
    y = keras.layers.Conv2D(
        filters=2, kernel_size=1, strides=1, padding="same", name="conv_cls.8"
    )(y)
    if backbone_name != "vgg":
        y = keras.layers.Activation("sigmoid")(y)
    model = keras.models.Model(inputs=inputs, outputs=y)
    if weights_path is not None:
        if weights_path.endswith(".h5"):
            model.load_weights(weights_path)
        elif weights_path.endswith(".pth"):
            assert (
                backbone_name == "vgg"
            ), "PyTorch weights only allowed with VGG backbone."
            load_torch_weights(model=model, weights_path=weights_path)
        else:
            raise NotImplementedError(f"Cannot load weights from {weights_path}")
    return model


# pylint: disable=import-error
def load_torch_weights(model, weights_path):
    import torch

    pretrained = torch.load(weights_path, map_location=torch.device("cpu"))
    layer_names = list(
        set(
            ".".join(k.split(".")[1:-1])
            for k in pretrained.keys()
            if k.split(".")[-1] != "num_batches_tracked"
        )
    )
    for layer_name in layer_names:
        try:
            layer = model.get_layer(layer_name)
        except Exception:  # pylint: disable=broad-except
            print("Skipping", layer.name)
            continue
        if isinstance(layer, keras.layers.BatchNormalization):
            gamma, beta, running_mean, running_std = [
                pretrained[k].numpy()
                for k in [
                    f"module.{layer_name}.weight",
                    f"module.{layer_name}.bias",
                    f"module.{layer_name}.running_mean",
                    f"module.{layer_name}.running_var",
                ]
            ]
            layer.set_weights([gamma, beta, running_mean, running_std])
        elif isinstance(layer, keras.layers.Conv2D):
            weights, bias = [
                pretrained[k].numpy()
                for k in [f"module.{layer_name}.weight", f"module.{layer_name}.bias"]
            ]
            layer.set_weights([weights.transpose(2, 3, 1, 0), bias])

        else:
            raise NotImplementedError

    for layer in model.layers:
        if isinstance(layer, (keras.layers.BatchNormalization, keras.layers.Conv2D)):
            assert layer.name in layer_names


# pylint: disable=import-error,too-few-public-methods
def build_torch_model(weights_path=None):
    from collections import namedtuple, OrderedDict

    import torch
    import torchvision

    def init_weights(modules):
        for m in modules:
            if isinstance(m, torch.nn.Conv2d):
                torch.nn.init.xavier_uniform_(m.weight.data)
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, torch.nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, torch.nn.Linear):
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()

    class vgg16_bn(torch.nn.Module):
        def __init__(self, pretrained=True, freeze=True):
            super().__init__()
            # We don't bother loading the pretrained VGG
            # because we're going to use the weights
            # at weights_path.
            vgg_pretrained_features = torchvision.models.vgg16_bn(
                pretrained=False
            ).features
            self.slice1 = torch.nn.Sequential()
            self.slice2 = torch.nn.Sequential()
            self.slice3 = torch.nn.Sequential()
            self.slice4 = torch.nn.Sequential()
            self.slice5 = torch.nn.Sequential()
            for x in range(12):  # conv2_2
                self.slice1.add_module(str(x), vgg_pretrained_features[x])
            for x in range(12, 19):  # conv3_3
                self.slice2.add_module(str(x), vgg_pretrained_features[x])
            for x in range(19, 29):  # conv4_3
                self.slice3.add_module(str(x), vgg_pretrained_features[x])
            for x in range(29, 39):  # conv5_3
                self.slice4.add_module(str(x), vgg_pretrained_features[x])

            # fc6, fc7 without atrous conv
            self.slice5 = torch.nn.Sequential(
                torch.nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
                torch.nn.Conv2d(512, 1024, kernel_size=3, padding=6, dilation=6),
                torch.nn.Conv2d(1024, 1024, kernel_size=1),
            )

            if not pretrained:
                init_weights(self.slice1.modules())
                init_weights(self.slice2.modules())
                init_weights(self.slice3.modules())
                init_weights(self.slice4.modules())

            init_weights(self.slice5.modules())  # no pretrained model for fc6 and fc7

            if freeze:
                for param in self.slice1.parameters():  # only first conv
                    param.requires_grad = False

        def forward(self, X):  # pylint: disable=arguments-differ
            h = self.slice1(X)
            h_relu2_2 = h
            h = self.slice2(h)
            h_relu3_2 = h
            h = self.slice3(h)
            h_relu4_3 = h
            h = self.slice4(h)
            h_relu5_3 = h
            h = self.slice5(h)
            h_fc7 = h
            vgg_outputs = namedtuple(
                "vgg_outputs", ["fc7", "relu5_3", "relu4_3", "relu3_2", "relu2_2"]
            )
            out = vgg_outputs(h_fc7, h_relu5_3, h_relu4_3, h_relu3_2, h_relu2_2)
            return out

    class double_conv(torch.nn.Module):
        def __init__(self, in_ch, mid_ch, out_ch):
            super().__init__()
            self.conv = torch.nn.Sequential(
                torch.nn.Conv2d(in_ch + mid_ch, mid_ch, kernel_size=1),
                torch.nn.BatchNorm2d(mid_ch),
                torch.nn.ReLU(inplace=True),
                torch.nn.Conv2d(mid_ch, out_ch, kernel_size=3, padding=1),
                torch.nn.BatchNorm2d(out_ch),
                torch.nn.ReLU(inplace=True),
            )

        def forward(self, x):  # pylint: disable=arguments-differ
            x = self.conv(x)
            return x

    class CRAFT(torch.nn.Module):
        def __init__(self, pretrained=False, freeze=False):
            super().__init__()
            # Base network
            self.basenet = vgg16_bn(pretrained, freeze)
            # U network
            self.upconv1 = double_conv(1024, 512, 256)
            self.upconv2 = double_conv(512, 256, 128)
            self.upconv3 = double_conv(256, 128, 64)
            self.upconv4 = double_conv(128, 64, 32)

            num_class = 2
            self.conv_cls = torch.nn.Sequential(
                torch.nn.Conv2d(32, 32, kernel_size=3, padding=1),
                torch.nn.ReLU(inplace=True),
                torch.nn.Conv2d(32, 32, kernel_size=3, padding=1),
                torch.nn.ReLU(inplace=True),
                torch.nn.Conv2d(32, 16, kernel_size=3, padding=1),
                torch.nn.ReLU(inplace=True),
                torch.nn.Conv2d(16, 16, kernel_size=1),
                torch.nn.ReLU(inplace=True),
                torch.nn.Conv2d(16, num_class, kernel_size=1),
            )

            init_weights(self.upconv1.modules())
            init_weights(self.upconv2.modules())
            init_weights(self.upconv3.modules())
            init_weights(self.upconv4.modules())
            init_weights(self.conv_cls.modules())

        def forward(self, x):  # pylint: disable=arguments-differ
            # Base network
            sources = self.basenet(x)
            # U network
            # pylint: disable=E1101
            y = torch.cat([sources[0], sources[1]], dim=1)

            y = self.upconv1(y)

            y = torch.nn.functional.interpolate(
                y, size=sources[2].size()[2:], mode="bilinear", align_corners=False
            )
            y = torch.cat([y, sources[2]], dim=1)
            y = self.upconv2(y)

            y = torch.nn.functional.interpolate(
                y, size=sources[3].size()[2:], mode="bilinear", align_corners=False
            )
            y = torch.cat([y, sources[3]], dim=1)
            y = self.upconv3(y)

            y = torch.nn.functional.interpolate(
                y, size=sources[4].size()[2:], mode="bilinear", align_corners=False
            )
            y = torch.cat([y, sources[4]], dim=1)
            # pylint: enable=E1101
            feature = self.upconv4(y)

            y = self.conv_cls(feature)

            return y.permute(0, 2, 3, 1), feature

    def copyStateDict(state_dict):
        if list(state_dict.keys())[0].startswith("module"):
            start_idx = 1
        else:
            start_idx = 0
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            name = ".".join(k.split(".")[start_idx:])
            new_state_dict[name] = v
        return new_state_dict

    model = CRAFT(pretrained=True).eval()
    if weights_path is not None:
        model.load_state_dict(
            copyStateDict(torch.load(weights_path, map_location=torch.device("cpu")))
        )
    return model


PRETRAINED_WEIGHTS = {
    ("clovaai_general", True): {
        "url": "https://github.com/faustomorales/keras-ocr/releases/download/v0.8.4/craft_mlt_25k.pth",
        "filename": "craft_mlt_25k.pth",
        "sha256": "4a5efbfb48b4081100544e75e1e2b57f8de3d84f213004b14b85fd4b3748db17",
    },
    ("clovaai_general", False): {
        "url": "https://github.com/faustomorales/keras-ocr/releases/download/v0.8.4/craft_mlt_25k.h5",
        "filename": "craft_mlt_25k.h5",
        "sha256": "7283ce2ff05a0617e9740c316175ff3bacdd7215dbdf1a726890d5099431f899",
    },
}


class Detector:
    """A text detector using the CRAFT architecture.

    Args:
        weights: The weights to use for the model. Currently, only `clovaai_general`
            is supported.
        load_from_torch: Whether to load the weights from the original PyTorch weights.
        optimizer: The optimizer to use for training the model.
        backbone_name: The backbone to use. Currently, only 'vgg' is supported.
    """

    def __init__(
        self,
        weights="clovaai_general",
        load_from_torch=False,
        optimizer="adam",
        backbone_name="vgg",
    ):
        if weights is not None:
            pretrained_key = (weights, load_from_torch)
            assert backbone_name == "vgg", "Pretrained weights available only for VGG."
            assert (
                pretrained_key in PRETRAINED_WEIGHTS
            ), "Selected weights configuration not found."
            weights_config = PRETRAINED_WEIGHTS[pretrained_key]
            weights_path = tools.download_and_verify(
                url=weights_config["url"],
                filename=weights_config["filename"],
                sha256=weights_config["sha256"],
            )
        else:
            weights_path = None
        self.model = build_keras_model(
            weights_path=weights_path, backbone_name=backbone_name
        )
        self.model.compile(loss="mse", optimizer=optimizer)

    def get_batch_generator(
        self,
        image_generator,
        batch_size=8,
        heatmap_size=512,
        heatmap_distance_ratio=1.5,
    ):
        """Get a generator of X, y batches to train the detector.

        Args:
            image_generator: A generator with the same signature as
                keras_ocr.tools.get_image_generator. Optionally, a third
                entry in the tuple (beyond image and lines) can be provided
                which will be interpreted as the sample weight.
            batch_size: The size of batches to generate.
            heatmap_size: The size of the heatmap to pass to get_gaussian_heatmap
            heatmap_distance_ratio: The distance ratio to pass to
                get_gaussian_heatmap. The larger the value, the more tightly
                concentrated the heatmap becomes.
        """
        heatmap = get_gaussian_heatmap(
            size=heatmap_size, distanceRatio=heatmap_distance_ratio
        )
        while True:
            batch = [next(image_generator) for n in range(batch_size)]
            images = np.array([entry[0] for entry in batch])
            line_groups = [entry[1] for entry in batch]
            X = compute_input(images)
            # pylint: disable=unsubscriptable-object
            y = np.array(
                [
                    compute_maps(
                        heatmap=heatmap,
                        image_height=images.shape[1],
                        image_width=images.shape[2],
                        lines=lines,
                    )
                    for lines in line_groups
                ]
            )
            # pylint: enable=unsubscriptable-object
            if len(batch[0]) == 3:
                sample_weights = np.array([sample[2] for sample in batch])
                yield X, y, sample_weights
            else:
                yield X, y

    def detect(
        self,
        images: typing.List[typing.Union[np.ndarray, str]],
        detection_threshold=0.7,
        text_threshold=0.4,
        link_threshold=0.4,
        size_threshold=10,
        **kwargs,
    ):
        """Recognize the text in a set of images.

        Args:
            images: Can be a list of numpy arrays of shape HxWx3 or a list of
                filepaths.
            link_threshold: This is the same as `text_threshold`, but is applied to the
                link map instead of the text map.
            detection_threshold: We want to avoid including boxes that may have
                represented large regions of low confidence text predictions. To do this,
                we do a final check for each word box to make sure the maximum confidence
                value exceeds some detection threshold. This is the threshold used for
                this check.
            text_threshold: When the text map is processed, it is converted from confidence
                (float from zero to one) values to classification (0 for not text, 1 for
                text) using binary thresholding. The threshold value determines the
                breakpoint at which a value is converted to a 1 or a 0. For example, if
                the threshold is 0.4 and a value for particular point on the text map is
                0.5, that value gets converted to a 1. The higher this value is, the less
                likely it is that characters will be merged together into a single word.
                The lower this value is, the more likely it is that non-text will be detected.
                Therein lies the balance.
            size_threshold: The minimum area for a word.
        """
        images = [compute_input(tools.read(image)) for image in images]
        boxes = getBoxes(
            self.model.predict(np.array(images), **kwargs),
            detection_threshold=detection_threshold,
            text_threshold=text_threshold,
            link_threshold=link_threshold,
            size_threshold=size_threshold,
        )
        return boxes
